Hand-Written Digit Classification

In this module, we will examine the MNIST dataset, which is a set of 70,000 images of digits handwritten by high school students and employees of the US Census Bureau.

MNIST is considered the “hello world” of machine learning, and is often a good place to start when learning about classification algorithms.

Let’s load the MNIST dataset.

library(MicrosoftML)
library(tidyverse)
library(magrittr)

library(dplyrXdf)
theme_set(theme_minimal())
mnist_xdf <- file.path("..", "data", "MNIST.xdf")
mnist_xdf <- RxXdfData(mnist_xdf)

Let’s take a look at the data:

rxGetInfo(mnist_xdf)
File name: /home/alizaidi/learnAnalytics-MicrosoftML/Student-Resources/data/MNIST.xdf 
Number of observations: 70000 
Number of variables: 786 
Number of blocks: 7 
Compression type: zlib 

Our dataset contains 70K records and 786 columns. Of those, 784 are features: one intensity value per pixel of each 28x28 image. The two additional columns hold the label and a pre-sampled train/test split indicator.
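As a quick sanity check, the column count decomposes as 784 pixels plus the two metadata columns, and any row's pixel vector can be reshaped back into a 28x28 image matrix. A minimal base-R sketch (the placeholder pixel vector and row-major ordering are assumptions for illustration):

```r
# 784 pixel columns + Label + splitVar = 786 total columns
28 * 28 + 2
#> [1] 786

# A length-784 pixel vector reshapes into a 28 x 28 image matrix
pixels <- runif(784)                            # stand-in for one row's pixel intensities
img <- matrix(pixels, nrow = 28, byrow = TRUE)  # assumes row-major pixel order
dim(img)
#> [1] 28 28
```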

Visualizing Digits

Let’s make some visualizations to examine the MNIST data and see what features a classifier could use to distinguish the digits.

mnist_df <- rxDataStep(inData = mnist_xdf, outFile = NULL,
                       maxRowsByCols = nrow(mnist_xdf)*ncol(mnist_xdf)) %>% tbl_df

Let’s see the average for each digit:

mnist_df %>% 
  keep(is.numeric) %>% 
  rowMeans() %>% data.frame(intensity = .) %>% 
  tbl_df %>% 
  bind_cols(mnist_df) %T>% print -> mnist_df

Visualize average intensity by label:

ggplot(mnist_df, aes(x = intensity, y = ..density..)) +
  geom_density(aes(fill = Label), alpha = 0.3)

Let’s try a boxplot:

ggplot(mnist_df, aes(x = Label, y = intensity)) +
  geom_boxplot(aes(fill = Label), alpha = 0.3)

Visualize Digits

Let’s plot a sample set of digits:

flip <- function(matrix) {
  apply(matrix, 2, rev)
}
plot_digit <- function(samp) {
  
  digit <- unlist(samp)
  m <- flip(matrix(rev(as.numeric(digit)), nrow = 28))
  image(m, col = grey.colors(255))
  
}
mnist_df[11, ] %>% 
  select(-Label, -intensity, -splitVar) %>% 
  sample_n(1) %>% 
  rowwise() %>% plot_digit

Split the Data into Train and Test Sets

splits <- rxSplit(mnist_xdf,
                  splitByFactor = "splitVar", 
                  overwrite = TRUE)
names(splits) <- c("train", "test")

Let’s first train a softmax classifier using the rxLogisticRegression function:

softmax <- estimate_model(xdf_data = splits$train,
                          form = make_form(splits$train, 
                                           resp_var = "Label", 
                                           vars_to_skip = c("splitVar")),
                          model = rxLogisticRegression,
                          type = "multiClass")
Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.
LBFGS multi-threading will attempt to load dataset into memory. In case of out-of-memory issues, turn off multi-threading by setting trainThreads to 1.
Beginning optimization
num vars: 7850
improvement criterion: Mean Improvement
L1 regularization selected 3712 of 7850 weights.
Not training a calibrator because it is not needed.
Elapsed time: 00:00:17.4634647
Elapsed time: 00:00:00.0938839
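The estimate_model and make_form helpers used above are convenience functions defined elsewhere in this repository. A minimal sketch of what they might do (the bodies below are assumptions for illustration, not the repository's actual definitions):

```r
# Hypothetical sketch -- the real helpers live in the course's R scripts.
make_form <- function(xdf_data, resp_var, vars_to_skip = character(0)) {
  # Build "Label ~ V1 + V2 + ..." from the dataset's column names,
  # dropping the response and any variables to skip.
  vars <- names(rxGetVarInfo(xdf_data))
  features <- setdiff(vars, c(resp_var, vars_to_skip))
  as.formula(paste(resp_var, "~", paste(features, collapse = " + ")))
}

estimate_model <- function(xdf_data, form, model = rxLogisticRegression, ...) {
  # Fit the chosen MicrosoftML model with the constructed formula;
  # extra arguments (e.g., type = "multiClass") pass through to the model.
  model(formula = form, data = xdf_data, ...)
}
```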

Let’s see how we did by examining our results on the test set:

softmax_scores <- rxPredict(modelObject = softmax, 
                            data = splits$test, 
                            outData = tempfile(fileext = ".xdf"),
                            overwrite = TRUE,
                            extraVarsToWrite = "Label")
Elapsed time: 00:00:00.9884308

We can make a confusion matrix of all our results:

rxCube(~ Label : PredictedLabel, data = softmax_scores,
       returnDataFrame = TRUE) -> softmax_scores_df
softmax_scores_df %>% ggplot(aes(x = Label, y = PredictedLabel,
                                 fill = Counts)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue")

Here we are plotting the raw counts, which can over-emphasize the more populated classes. Let’s normalize each count by the total number of samples in its class:

label_rates <- softmax_scores_df %>% 
  tbl_df %>% 
  group_by(Label) %>% 
  mutate(rate = Counts/sum(Counts))
label_rates %>% ggplot(aes(x = Label, y = PredictedLabel, fill = rate)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue")

Let’s zero out the correct predictions so we can see the errors more clearly:

label_rates %>%
  mutate(error_rate = ifelse(Label == PredictedLabel,
                             0, rate)) %>% 
  ggplot(aes(x = Label, y = PredictedLabel, fill = error_rate)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue",
                        labels = scales::percent)
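The same label_rates table also yields an overall accuracy figure: sum the counts where the label and prediction agree and divide by the total. A short sketch using the columns created above:

```r
# Overall accuracy: diagonal counts over total counts
label_rates %>%
  ungroup() %>%
  summarise(accuracy = sum(Counts[Label == PredictedLabel]) / sum(Counts))
```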

Exercises

  1. Take a look at David Robinson’s tweet on using a single pixel to distinguish between pairs of digits.
  2. You can find his gist saved in the Rscripts directory.